{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7yuytuIllsv1"
      },
      "source": [
        "# A Conceptual, Practical Introduction to Trax Layers\n",
        "\n",
        "This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in following sections are:\n",
        "\n",
        "  1. **Layers**: the basic building blocks and how to combine them into networks\n",
        "  1. **Data Streams**: how individual layers manage inputs and outputs\n",
        "  1. **Data Stack**: how the Trax runtime manages data streams for the layers\n",
        "  1. **Defining New Layer Classes**: how to define and test your own layer classes\n",
        "  1. **Models**: how to train, evaluate, and run predictions with Trax models\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BIl27504La0G"
      },
      "source": [
        "## General Setup\n",
        "Execute the following few cells (once) before running any of the code samples in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oILRLCWN_16u"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "import numpy as onp  # np used below for trax.math.numpy\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 244,
          "status": "ok",
          "timestamp": 1576195193158,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "vlGjGoGMTt-D",
        "outputId": "5d6ac72b-30cc-44fe-a81f-d62d4f026080"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "/bin/sh: pip: command not found\n",
            "/bin/sh: pip: command not found\n"
          ]
        }
      ],
      "source": [
        "# Import Trax\n",
        "\n",
        "! pip install -q -U trax\n",
        "! pip install -q tensorflow\n",
        "\n",
        "from trax import math\n",
        "from trax import layers as tl\n",
        "from trax import shapes\n",
        "from trax.math import numpy as np  # For use in defining new layer types.\n",
        "from trax.shapes import ShapeDtype\n",
        "from trax.shapes import signature"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bYWNWL9MJHv9"
      },
      "outputs": [],
      "source": [
        "# Settings and utilities for handling inputs, outputs, and object properties.\n",
        "\n",
        "onp.set_printoptions(precision=3)  # Reduce visual noise from extra digits.\n",
        "\n",
        "def show_layer_properties(layer_obj, layer_name):\n",
        "  template = ('{}.n_in:  {}\\n'\n",
        "              '{}.n_out: {}\\n'\n",
        "              '{}.sublayers: {}\\n'\n",
        "              '{}.weights:    {}\\n')\n",
        "  print(template.format(layer_name, layer_obj.n_in,\n",
        "                        layer_name, layer_obj.n_out,\n",
        "                        layer_name, layer_obj.sublayers,\n",
        "                        layer_name, layer_obj.weights))\n",
        "\n",
        "def floats_in_range(start, end):\n",
        "  return onp.arange(start, end).astype(onp.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-LQ89rFFsEdk"
      },
      "source": [
        "# 1. Layers\n",
        "\n",
        "The Layer class represents Trax's basic building blocks:\n",
        "```\n",
        "  \"\"\"Base class for composable layers in a deep learning network.\n",
        "\n",
        "  Layers are the basic building blocks for deep learning models. A Trax layer\n",
        "  computes a function from zero or more inputs to zero or more outputs,\n",
        "  optionally using trainable weights (common) and non-parameter state (not\n",
        "  common). Authors of new layer subclasses typically override at most two\n",
        "  methods of the base `Layer` class:\n",
        "\n",
        "    forward(inputs, weights):\n",
        "      Computes this layer's output as part of a forward pass through the model.\n",
        "\n",
        "    new_weights(self, input_signature):\n",
        "      Returns new weights suitable for inputs with the given signature.\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LyLVtdxorDPO"
      },
      "source": [
        "## A layer computes a function."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ntZ4_eNQldzL"
      },
      "source": [
        "A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.\n",
        "\n",
        "The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them (and test them) like simple mathematical functions. For ease of testing and interactive exploration, layer\n",
        "objects implement the `__call__ ` method, so you can call them directly on input data:\n",
        "```\n",
        "y = my_layer(x)\n",
        "```\n",
        "\n",
        "Layers are also objects, so you can inspect their properties. For example:\n",
        "```\n",
        "print('Number of inputs expected by this layer: {}'.format(my_layer.n_in))\n",
        "```\n",
        "\n",
        "### Example 1. tl.Relu $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 221
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 706,
          "status": "ok",
          "timestamp": 1576195194212,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "V09viOSEQvQe",
        "outputId": "21043d74-a9b9-488c-cfe5-290f4bb47e99"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[-7. -6. -5. -4. -3.]\n",
            " [-2. -1.  0.  1.  2.]\n",
            " [ 3.  4.  5.  6.  7.]]\n",
            "\n",
            "relu(x):\n",
            "[[0. 0. 0. 0. 0.]\n",
            " [0. 0. 0. 1. 2.]\n",
            " [3. 4. 5. 6. 7.]]\n",
            "\n",
            "number of inputs expected by this layer: 1\n",
            "number of outputs promised by this layer: 1\n"
          ]
        }
      ],
      "source": [
        "relu = tl.Relu()\n",
        "\n",
        "x = floats_in_range(-7, 8).reshape(3, -1)\n",
        "y = relu(x)\n",
        "\n",
        "# Show input, output, and two layer properties.\n",
        "template = ('x:\\n{}\\n\\n'\n",
        "            'relu(x):\\n{}\\n\\n'\n",
        "            'number of inputs expected by this layer: {}\\n'\n",
        "            'number of outputs promised by this layer: {}')\n",
        "print(template.format(x, y, relu.n_in, relu.n_out))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7sYxIT8crFVE"
      },
      "source": [
        "### Example 2. tl.Concatenate $[n_{in} = 2, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 442
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 379,
          "status": "ok",
          "timestamp": 1576195194608,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "LMPPNWXLoOZI",
        "outputId": "e802d00d-9d22-4d0d-eb08-0086317f0ccf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x1:\n",
            "[[-7. -6. -5. -4. -3.]\n",
            " [-2. -1.  0.  1.  2.]\n",
            " [ 3.  4.  5.  6.  7.]]\n",
            "\n",
            "x2:\n",
            "[[-70. -60. -50. -40. -30.]\n",
            " [-20. -10.   0.  10.  20.]\n",
            " [ 30.  40.  50.  60.  70.]]\n",
            "\n",
            "concat_axis_0([x1, x2]):\n",
            "[[ -7.  -6.  -5.  -4.  -3.]\n",
            " [ -2.  -1.   0.   1.   2.]\n",
            " [  3.   4.   5.   6.   7.]\n",
            " [-70. -60. -50. -40. -30.]\n",
            " [-20. -10.   0.  10.  20.]\n",
            " [ 30.  40.  50.  60.  70.]]\n",
            "\n",
            "concat_axis_1([x1, x2]):\n",
            "[[ -7.  -6.  -5.  -4.  -3. -70. -60. -50. -40. -30.]\n",
            " [ -2.  -1.   0.   1.   2. -20. -10.   0.  10.  20.]\n",
            " [  3.   4.   5.   6.   7.  30.  40.  50.  60.  70.]]\n",
            "\n",
            "concat_axis_0: Concatenate{in=2,out=1}\n",
            "concat_axis_1: Concatenate{in=2,out=1}\n"
          ]
        }
      ],
      "source": [
        "concat_axis_0 = tl.Concatenate(axis=0)\n",
        "concat_axis_1 = tl.Concatenate(axis=1)\n",
        "\n",
        "x1 = floats_in_range(-7, 8).reshape(3, -1)\n",
        "x2 = x1 * 10.\n",
        "y0 = concat_axis_0([x1, x2])\n",
        "y1 = concat_axis_1([x1, x2])\n",
        "\n",
        "template = ('x1:\\n{}\\n\\n'\n",
        "            'x2:\\n{}\\n\\n'\n",
        "            'concat_axis_0([x1, x2]):\\n{}\\n\\n'\n",
        "            'concat_axis_1([x1, x2]):\\n{}\\n')\n",
        "print(template.format(x1, x2, y0, y1))\n",
        "\n",
        "# Print abbreviated object representations (useful for debugging).\n",
        "print('concat_axis_0: {}'.format(concat_axis_0))\n",
        "print('concat_axis_1: {}'.format(concat_axis_1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1oZv3R8bRMvF"
      },
      "source": [
        "## Layers are trainable."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Yyu56aUubyxr"
      },
      "source": [
        "Most layer types include weights that affect the computation of outputs from inputs, and they use back-progagated gradients to update those weights.\n",
        "\n",
        "🚧🚧 *A very small subset of layer types, such as `BatchNorm`, also include weights (called `state`) that are updated based on forward-pass inputs/computation rather than back-propagated gradients.*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3d64M7wLryji"
      },
      "source": [
        "### Initialization\n",
        "\n",
        "Trainable layers must be initialized before use. Trax's model trainers take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the *outermost/topmost* layer explicitly. For this, use `init`:\n",
        "\n",
        "```\n",
        "  def init(self, input_signature, rng=None):\n",
        "    \"\"\"Initializes this layer and its sublayers recursively.\n",
        "\n",
        "    This method is designed to initialize each layer instance once, even if the\n",
        "    same layer instance occurs in multiple places in the network. This enables\n",
        "    weight sharing to be implemented as layer sharing.\n",
        "\n",
        "    Args:\n",
        "      input_signature: A `ShapeDtype` instance (if this layer takes one input)\n",
        "          or a list/tuple of `ShapeDtype` instances.\n",
        "      rng: A single-use random number generator (JAX PRNG key). If none is\n",
        "          provided, a default rng based on the integer seed 0 will be used.\n",
        "\n",
        "    Returns:\n",
        "      A (weights, state) tuple, in which weights contains newly created weights\n",
        "          on the first call and `EMPTY_WEIGHTS` on all subsequent calls.\n",
        "    \"\"\"\n",
        "```\n",
        "\n",
        "Input signatures can be built from scratch using `ShapeDType` objects, or can\n",
        "be derived from data via the `signature` function:\n",
        "```\n",
        "def signature(obj):\n",
        "  \"\"\"Returns a `ShapeDtype` signature for the given `obj`.\n",
        "\n",
        "  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`\n",
        "  instances. Note that this function is permissive with respect to its inputs\n",
        "  (accepts lists or tuples, and underlying objects can be any type as long as\n",
        "  they have shape and dtype attributes), but strict with respect to its outputs\n",
        "  (only `ShapeDtype`, and only tuples).\n",
        "\n",
        "  Args:\n",
        "    obj: An object that has `shape` and `dtype` attributes, or a list/tuple\n",
        "        of such objects.\n",
        "  \"\"\"\n",
        "```"
      ]
    },
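    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The idea behind `signature` can be illustrated without Trax. The following is a minimal sketch, not the library's implementation: `ShapeDtypeSketch` and `signature_sketch` are hypothetical stand-ins that only mimic the behavior described in the docstring above (permissive about inputs, strict about outputs).\n",
        "\n",
        "```python\n",
        "import collections\n",
        "\n",
        "import numpy as onp\n",
        "\n",
        "# Hypothetical stand-in for trax.shapes.ShapeDtype (illustration only).\n",
        "ShapeDtypeSketch = collections.namedtuple('ShapeDtypeSketch', ['shape', 'dtype'])\n",
        "\n",
        "\n",
        "def signature_sketch(obj):\n",
        "  # A list or tuple of array-like objects yields a tuple of signatures.\n",
        "  if isinstance(obj, (list, tuple)):\n",
        "    return tuple(signature_sketch(x) for x in obj)\n",
        "  # Anything with shape and dtype attributes yields a single signature.\n",
        "  return ShapeDtypeSketch(shape=tuple(obj.shape), dtype=onp.dtype(obj.dtype))\n",
        "\n",
        "\n",
        "x = onp.zeros((3, 5), dtype=onp.float32)\n",
        "single = signature_sketch(x)           # One array in, one signature out.\n",
        "pair = signature_sketch([x, x[:, 0]])  # A list in, a tuple out.\n",
        "print(single)\n",
        "print(pair)\n",
        "```\n",
        "\n",
        "Because a signature carries only shape and dtype, initialization never needs real data; an array of zeros with the right shape is enough to derive one."
      ]
    },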
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yL8HAj6GEAp1"
      },
      "source": [
        "### Example 3. tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 221
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 872,
          "status": "ok",
          "timestamp": 1576195195496,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "Ie7iyX91qAx2",
        "outputId": "d101eeec-1861-410b-a3e5-3d6c0701c03b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[-7. -6. -5. -4. -3.]\n",
            " [-2. -1.  0.  1.  2.]\n",
            " [ 3.  4.  5.  6.  7.]]\n",
            "\n",
            "layer_norm(x):\n",
            "[[-1.414 -0.707  0.     0.707  1.414]\n",
            " [-1.414 -0.707  0.     0.707  1.414]\n",
            " [-1.414 -0.707  0.     0.707  1.414]]\n",
            "\n",
            "layer_norm.weights:\n",
            "(_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))\n"
          ]
        }
      ],
      "source": [
        "layer_norm = tl.LayerNorm()\n",
        "\n",
        "x = floats_in_range(-7, 8).reshape(3, -1)\n",
        "layer_norm.init(signature(x))\n",
        "\n",
        "y = layer_norm(x)\n",
        "\n",
        "template = ('x:\\n{}\\n\\n'\n",
        "            'layer_norm(x):\\n{}\\n')\n",
        "print(template.format(x, y))\n",
        "print('layer_norm.weights:\\n{}'.format(layer_norm.weights))"
      ]
    },
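    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The values printed above can be checked by hand. The sketch below assumes that layer normalization subtracts the mean and divides by the standard deviation along the last axis, then applies the scale and bias shown in `layer_norm.weights` (all ones and all zeros here). The helper `layer_norm_sketch` and its `epsilon` value are assumptions for illustration, not Trax code.\n",
        "\n",
        "```python\n",
        "import numpy as onp\n",
        "\n",
        "\n",
        "def layer_norm_sketch(x, scale, bias, epsilon=1e-6):\n",
        "  # Normalize each row (last axis) to zero mean and unit variance.\n",
        "  mean = onp.mean(x, axis=-1, keepdims=True)\n",
        "  variance = onp.mean((x - mean) ** 2, axis=-1, keepdims=True)\n",
        "  normalized = (x - mean) / onp.sqrt(variance + epsilon)\n",
        "  # Apply the (here trivial) learned scale and bias.\n",
        "  return normalized * scale + bias\n",
        "\n",
        "\n",
        "x = onp.arange(-7, 8).astype(onp.float32).reshape(3, -1)\n",
        "scale = onp.ones(5, dtype=onp.float32)\n",
        "bias = onp.zeros(5, dtype=onp.float32)\n",
        "y = layer_norm_sketch(x, scale, bias)\n",
        "print(onp.round(y, 3))\n",
        "```\n",
        "\n",
        "Each row of `x` has the same spread around its own mean (variance 2), which is why all three rows normalize to the same values."
      ]
    },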
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZWZUXEJAofH-"
      },
      "source": [
        "## Layers combine into layers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "d47gVdGV1vWw"
      },
      "source": [
        "The Trax library authors encourage users to build new layers as combinations of existing layers. Hence, the library provides a small set of _combinator_ layers: layer objects that make a list of layers behave as a single layer.\n",
        "\n",
        "The new layer, like other layers, can:\n",
        "* compute outputs from inputs,\n",
        "* update parameters from gradients, and\n",
        "* combine with yet more layers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vC1ymG2j0iyp"
      },
      "source": [
        "### Combine with `Serial`\n",
        "\n",
        "The most common way to combine layers is with the `Serial` class:\n",
        "```\n",
        "class Serial(base.Layer):\n",
        "  \"\"\"Combinator that applies layers serially (by function composition).\n",
        "\n",
        "  A Serial combinator uses stack semantics to manage data for its sublayers.\n",
        "  Each sublayer sees only the inputs it needs and returns only the outputs it\n",
        "  has generated. The sublayers interact via the data stack. For instance, a\n",
        "  sublayer k, following sublayer j, gets called with the data stack in the\n",
        "  state left after layer j has applied. The Serial combinator then:\n",
        "\n",
        "    - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n",
        "      layer k, passing those items as arguments; and\n",
        "\n",
        "    - takes layer k's n_out return values (n_out = k.n_out) and pushes\n",
        "      them onto the data stack.\n",
        "\n",
        "  ...\n",
        "```\n",
        "If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:\n",
        "\n",
        "```\n",
        "#  h(.) = g(f(.))\n",
        "layer_h = Serial(\n",
        "    layer_f,\n",
        "    layer_g,\n",
        ")\n",
        "```\n",
        "Note how, inside `Serial`, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.\n"
      ]
    },
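    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The stack semantics described in the docstring can be sketched in plain Python. This is an illustration under simplifying assumptions, not how Trax implements `Serial`: each sublayer is modeled as a hypothetical `(fn, n_in, n_out)` tuple, and the data stack is a Python list whose first element is the top of the stack.\n",
        "\n",
        "```python\n",
        "def run_serial_sketch(sublayers, inputs):\n",
        "  # The data stack; index 0 is the top of the stack.\n",
        "  stack = list(inputs)\n",
        "  for fn, n_in, n_out in sublayers:\n",
        "    # Take n_in items off the top of the stack and call the sublayer.\n",
        "    args, stack = stack[:n_in], stack[n_in:]\n",
        "    outputs = fn(*args)\n",
        "    if n_out == 0:\n",
        "      outputs = ()\n",
        "    elif n_out == 1:\n",
        "      outputs = (outputs,)\n",
        "    # Push the sublayer's n_out return values back onto the stack.\n",
        "    stack = list(outputs) + stack\n",
        "  return stack\n",
        "\n",
        "\n",
        "# Hypothetical sublayers: f doubles its input, g adds one.\n",
        "layer_f = (lambda x: 2 * x, 1, 1)\n",
        "layer_g = (lambda x: x + 1, 1, 1)\n",
        "# A two-input, one-output sublayer, to show a merge.\n",
        "layer_add = (lambda a, b: a + b, 2, 1)\n",
        "\n",
        "print(run_serial_sketch([layer_f, layer_g], [3]))       # h(3) = g(f(3))\n",
        "print(run_serial_sketch([layer_add, layer_g], [3, 4]))  # g(3 + 4)\n",
        "print(run_serial_sketch([layer_f], [3, 10]))            # 10 is left untouched\n",
        "```\n",
        "\n",
        "The last call shows that items a sublayer does not consume stay on the stack, available to later sublayers."
      ]
    },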
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uPOnrDa9ViPi"
      },
      "source": [
        "### Example 4. y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 170
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 2832,
          "status": "ok",
          "timestamp": 1576195198343,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "dW5fpusjvjmh",
        "outputId": "fe5e2bdd-1263-4599-c320-75ccee492360"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[-7. -6. -5. -4. -3.]\n",
            " [-2. -1.  0.  1.  2.]\n",
            " [ 3.  4.  5.  6.  7.]]\n",
            "\n",
            "layer_block(x):\n",
            "[[ 0.     0.     0.     0.     0.   ]\n",
            " [-0.75  -0.75  -0.75   0.5    1.75 ]\n",
            " [-1.414 -0.707  0.     0.707  1.414]]\n"
          ]
        }
      ],
      "source": [
        "layer_block = tl.Serial(\n",
        "    tl.Relu(),\n",
        "    tl.LayerNorm(),\n",
        ")\n",
        "\n",
        "x = floats_in_range(-7, 8).reshape(3, -1)\n",
        "layer_block.init(signature(x))\n",
        "y = layer_block(x)\n",
        "\n",
        "template = ('x:\\n{}\\n\\n'\n",
        "            'layer_block(x):\\n{}')\n",
        "print(template.format(x, y,))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bRtmN6ckQO1q"
      },
      "source": [
        "And we can inspect the block as a whole, as if it were just another layer:\n",
        "\n",
        "### Example 4'. Inspecting a Serial layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 102
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 260,
          "status": "ok",
          "timestamp": 1576195198618,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "D6BpYddZQ1eu",
        "outputId": "5293c446-f3b3-4fca-eda3-7e4cf11beae4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "layer_block:\n",
            "Serial{in=1,out=1,sublayers=[Relu{in=1,out=1}, LayerNorm{in=1,out=1}]}\n",
            "\n",
            "layer_block.weights:\n",
            "[(), (_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))]\n"
          ]
        }
      ],
      "source": [
        "print('layer_block:\\n{}\\n'.format(layer_block))\n",
        "\n",
        "print('layer_block.weights:\\n{}'.format(layer_block.weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kJ8bpYZtE66x"
      },
      "source": [
        "### Combine with `Branch`\n",
        "The `Branch` combinator arranges layers into parallel computational channels:\n",
        "```\n",
        "def Branch(*layers):\n",
        "  \"\"\"Combinator that applies a list of layers in parallel to copies of inputs.\n",
        "\n",
        "  Each layer in the input list is applied to as many inputs from the stack\n",
        "  as it needs, and their outputs are successively combined on stack.\n",
        "\n",
        "  For example, suppose one has three layers:\n",
        "\n",
        "    - F: 1 input, 1 output\n",
        "    - G: 3 inputs, 1 output\n",
        "    - H: 2 inputs, 2 outputs (h1, h2)\n",
        "\n",
        "  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:\n",
        "\n",
        "    - inputs: a, b, c\n",
        "    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)\n",
        "```"
      ]
    },
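    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `Branch(F, G, H)` example from the docstring can be traced with a small sketch. The code below is a simplified model and not the Trax implementation: sublayers are hypothetical `(fn, n_in, n_out)` tuples, and `F`, `G`, and `H` are made-up functions chosen only to match the input and output counts listed above.\n",
        "\n",
        "```python\n",
        "def run_branch_sketch(sublayers, inputs):\n",
        "  # The combined layer needs as many inputs as its most demanding sublayer.\n",
        "  n_in = max(layer_n_in for _, layer_n_in, _ in sublayers)\n",
        "  inputs = list(inputs)[:n_in]\n",
        "  outputs = []\n",
        "  for fn, layer_n_in, layer_n_out in sublayers:\n",
        "    # Each sublayer sees its own copy of the first layer_n_in inputs.\n",
        "    result = fn(*inputs[:layer_n_in])\n",
        "    outputs.extend([result] if layer_n_out == 1 else result)\n",
        "  return outputs\n",
        "\n",
        "\n",
        "F = (lambda a: -a, 1, 1)                 # 1 input, 1 output\n",
        "G = (lambda a, b, c: a + b + c, 3, 1)    # 3 inputs, 1 output\n",
        "H = (lambda a, b: (a * b, a - b), 2, 2)  # 2 inputs, 2 outputs (h1, h2)\n",
        "\n",
        "a, b, c = 1, 2, 3\n",
        "branch_outputs = run_branch_sketch([F, G, H], [a, b, c])\n",
        "print(branch_outputs)  # [F(a), G(a, b, c), h1, h2]\n",
        "```\n",
        "\n",
        "Three inputs go in and four outputs come out, matching the counts given in the docstring."
      ]
    },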
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RlPcnRtdIVgq"
      },
      "source": [
        "Residual blocks, for example, are implemented using `Branch`:\n",
        "```\n",
        "def Residual(*layers, **kwargs):\n",
        "  \"\"\"Wraps a series of layers with a residual connection.\n",
        "\n",
        "  Args:\n",
        "    *layers: One or more layers, to be applied in series.\n",
        "    **kwargs: If empty (the usual case), the Residual layer computes the\n",
        "        element-wise sum of the stack-top input with the output of the layer\n",
        "        series. If non-empty, the only key should be 'shortcut', whose value is\n",
        "        a layer that applies to a copy of the inputs and (elementwise) adds its\n",
        "        output to the output from the main layer series.\n",
        "\n",
        "  Returns:\n",
        "      A layer representing a residual connection paired with a layer series.\n",
        "  \"\"\"\n",
        "  shortcut = kwargs.get('shortcut')  # default None signals no-op (copy inputs)\n",
        "  layers = _ensure_flat(layers)\n",
        "  layer = layers[0] if len(layers) == 1 else Serial(layers)\n",
        "  return Serial(\n",
        "      Branch(shortcut, layer),\n",
        "      Add(),  # pylint: disable=no-value-for-parameter\n",
        "  )\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ruX4aFMdUOwS"
      },
      "source": [
        "Here's a simple code example to highlight the mechanics."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JGGnKjg4ESIg"
      },
      "source": [
        "### Example 5. Branch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 306
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 276,
          "status": "ok",
          "timestamp": 1576195198915,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 0
        },
        "id": "lw6A2YwuW-Ul",
        "outputId": "cd0c481f-fae9-4577-f143-2141eeffd82f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[-7. -6. -5. -4. -3.]\n",
            " [-2. -1.  0.  1.  2.]\n",
            " [ 3.  4.  5.  6.  7.]]\n",
            "\n",
            "y1:\n",
            "[[0. 0. 0. 0. 0.]\n",
            " [0. 0. 0. 1. 2.]\n",
            " [3. 4. 5. 6. 7.]]\n",
            "\n",
            "y2:\n",
            "[[-70. -60. -50. -40. -30.]\n",
            " [-20. -10.   0.  10.  20.]\n",
            " [ 30.  40.  50.  60.  70.]]\n",
            "\n",
            "number of inputs expected by this layer: 1\n",
            "number of outputs promised by this layer: 2\n"
          ]
        }
      ],
      "source": [
        "relu = tl.Relu()\n",
        "times_10 = tl.Fn(lambda x: x * 10.0)\n",
        "branch_relu_t10 = tl.Branch(relu, times_10)\n",
        "\n",
        "x = floats_in_range(-7, 8).reshape(3, -1)\n",
        "branch_relu_t10.init(signature(x))\n",
        "\n",
        "y1, y2 = branch_relu_t10(x)\n",
        "\n",
        "# Show input, outputs, and two layer properties.\n",
        "template = ('x:\\n{}\\n\\n'\n",
        "            'y1:\\n{}\\n\\n'\n",
        "            'y2:\\n{}\\n\\n'\n",
        "            'number of inputs expected by this layer: {}\\n'\n",
        "            'number of outputs promised by this layer: {}')\n",
        "print(template.format(x, y1, y2, branch_relu_t10.n_in, branch_relu_t10.n_out))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RObDHVC3fkzR"
      },
      "source": [
        "# 2. Data Streams"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zr2ZZ1vO8T8V"
      },
      "source": [
        "The Trax runtime supports the concept of multiple data streams, which gives individual layers flexibility to:\n",
        "  - process a single data stream ($n_{in} = n_{out} = 1$),\n",
        "  - process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),\n",
        "  - split data streams ($n_{in} \u003c n_{out}$), or\n",
        "  - merge data streams ($n_{in} \u003e n_{out}$).\n",
        "\n",
        "We saw in section 1 the example of `Residual`, which involves both a split and a merge:\n",
        "```\n",
        "  ...\n",
        "  return Serial(\n",
        "      Branch(shortcut, layer),\n",
        "      Add(),\n",
        "  )\n",
        "```\n",
        "In other words, layer by layer:\n",
        "  - `Branch(shortcut, layers)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers, applied in series. [$n_{in} = 1$, $n_{out} = 2$]\n",
        "  - `Add()`: combines the two streams back into one by adding elementwise. [$n_{in} = 2$, $n_{out} = 1$]"
      ]
    },
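    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The split-then-merge pattern of `Residual` can be written out in NumPy. This sketch assumes the usual case of a no-op shortcut and uses a hypothetical ReLU function in place of the wrapped layer series; it mirrors the data flow only, not the Trax implementation.\n",
        "\n",
        "```python\n",
        "import numpy as onp\n",
        "\n",
        "\n",
        "def relu_sketch(x):\n",
        "  return onp.maximum(x, 0.0)\n",
        "\n",
        "\n",
        "def residual_sketch(layer_fn, x):\n",
        "  # Branch: one copy passes through the shortcut, the other through layer_fn.\n",
        "  shortcut_stream, layer_stream = x, layer_fn(x)  # n_in = 1, n_out = 2\n",
        "  # Add: merge the two streams elementwise.\n",
        "  return shortcut_stream + layer_stream           # n_in = 2, n_out = 1\n",
        "\n",
        "\n",
        "x = onp.arange(-2, 3).astype(onp.float32)\n",
        "y = residual_sketch(relu_sketch, x)\n",
        "print(y)\n",
        "```\n",
        "\n",
        "Negative entries pass through unchanged (the ReLU branch contributes zero), while positive entries are doubled."
      ]
    },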
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QQVo6vhPgO9x"
      },
      "source": [
        "# 3. Data Stack"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "65ite-671cTT"
      },
      "source": [
        "# 4. Defining New Layer Classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hHSaD9H6hDTf"
      },
      "source": [
        "## Simpler layers, with the `@layer` decorator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cqM6WJwNhoHI"
      },
      "source": [
        "## Full subclass definitions, where necessary"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "llAH3cdE1UeU"
      },
      "source": [
        "# 5. Models"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "A Conceptual, Practical Introduction to Trax Layers",
      "provenance": [
        {
          "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
          "timestamp": 1569980697572
        },
        {
          "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
          "timestamp": 1563927451951
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
